import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse


def g_CauchyK_num(S):
    """Build a symbolic kernel-smoothed transform from the values in ``S``.

    For every entry ``s`` of ``S`` two simple-pole terms in the symbol ``z``
    are accumulated, one centred at ``-s`` and one at ``+s``, each shifted by
    the imaginary offset ``I*sqrt(1/(2*N))`` with ``N = len(S)``.  The sum is
    normalised by ``2*N``.

    Parameters
    ----------
    S : sequence of numbers
        Values at which the poles are placed (the caller in this file passes
        the singular values returned by ``numpy.linalg.svd``).

    Returns
    -------
    sympy expression in ``z``.
    """
    z = Symbol('z')
    N = len(S)

    # Imaginary offset shared by every pole; it depends only on N.
    shift = I*np.sqrt(1/(2*N))

    total = 0
    for s in S:
        mirrored_pole = 1/(z + s - shift)
        direct_pole = 1/(z - s - shift)
        total = total + mirrored_pole
        total = total + direct_pole

    return total/(2*N)

def Estimator(S_s, gX, gS, SNR, alpha, M=None):
    """Compute the estimated spectra for X and X^2 at every value in ``S_s``.

    Parameters
    ----------
    S_s : sequence of numbers
        Points at which the estimator is evaluated (the caller in this file
        passes the singular values of the observation matrix).
    gX : sympy expression in ``z``
        Evaluated at ``+sqrt(Z)`` and ``-sqrt(Z)`` for each point.
    gS : sympy expression in ``z``
        Evaluated at each point shifted by a small negative imaginary part.
    SNR : number
        Signal-to-noise ratio used to rescale the results.
    alpha : number
        Aspect ratio; the caller in this file passes ``N / M``.
    M : int, optional
        Second dimension of the observation matrix, used only to pick the
        smoothing factor ``dfr``.  When omitted it is recovered as
        ``round(len(S_s) / alpha)``, which is exact when ``alpha`` was
        computed as ``len(S_s) / M``; pass ``M`` explicitly otherwise.

    Returns
    -------
    (output_X, output_XX) : tuple of numpy.ndarray
        Two float arrays of length ``len(S_s)`` holding the estimated
        eigenvalues for X and for X^2 respectively.
    """
    N = len(S_s)

    if M is None:
        # The thresholds below previously read an undefined name ``M`` and
        # raised NameError.  Recover it from alpha = N / M; round() keeps
        # floating-point error in the division from missing a threshold.
        M = int(round(N / alpha))

    output_X = np.zeros(N)
    output_XX = np.zeros(N)

    z = Symbol('z')

    # Smoothing factor for the imaginary offset; grows with M.
    dfr = 64
    if M >= 3000:
        dfr = 128
    if M >= 5000:
        dfr = 256
    if M >= 7000:
        dfr = 512

    for i in range(N):

        #### optimal eigenvalue for X
        # Evaluate gS slightly below the real axis at the i-th point.
        zz = S_s[i] - I*np.sqrt(dfr/(2*N))
        gS_eval = gS.subs(z, zz).evalf()
        zeta = gS_eval + ((1-alpha)/alpha)*(1/zz)

        Z = (zz/zeta - 1)/SNR

        # gX is evaluated on both square roots of Z and the results summed.
        Est = gX.subs(z, sqrt(Z)).evalf() + gX.subs(z, -sqrt(Z)).evalf()

        output_X[i] = im(((Est/zeta)/(2*SNR*im(gS_eval))).evalf())

        #### optimal eigenvalue for X^2
        # Uses the unshifted (real) point S_s[i] in the 1/S_s[i] term.
        output_XX[i] = (-1 + 1/(alpha*(im(gS_eval)**2 + (re(gS_eval) + (-1 + 1/alpha)/S_s[i])**2)))/SNR

    return output_X, output_XX


def main():
    """Run the denoising experiment and save the relative errors to disk.

    Command line: ``-M <int>`` (required) is the number of columns of the
    noise matrices.  For each of ``Ex`` independent trials the function

    1. draws a symmetric Gaussian matrix X (scaled by 1/sqrt(N), shifted by
       3 * identity) and two Gaussian N x M matrices Y and W,
    2. forms the observation ``S = sqrt(SNR) * X @ Y + W`` and takes its SVD,
    3. computes the oracle spectra (diagonal of U^T X U and U^T X^2 U) and
       the spectra returned by ``Estimator``,
    4. records the squared Frobenius error of each reconstruction relative to
       the squared Frobenius norm of the target.

    Five ``.npy`` files (one array of length ``Ex`` each) are written to the
    current working directory.
    """
    z = Symbol('z')

    p = argparse.ArgumentParser()
    # Without -M the ratio N / M below cannot be computed, so require it.
    p.add_argument('-M', type=int, required=True)
    args = p.parse_args()

    N = 2000
    M = args.M
    if M <= 0:
        p.error('-M must be a positive integer')
    a = N/M
    SNR = 5

    Ex = 10

    E_X_oracle = np.zeros(Ex)
    E_X_RIE = np.zeros(Ex)
    E_X_sqXX = np.zeros(Ex)

    E_XX_oracle = np.zeros(Ex)
    E_XX_RIE = np.zeros(Ex)

    # Closed-form expression in z passed to Estimator for the spectrum of X
    # (branch points at 1 and 5).  It does not depend on the trial, so it is
    # built once.
    gX = (z - 3 - sqrt(z-5) * sqrt(z-1))/2

    for i in range(Ex):

        ## Signal
        X = np.triu(np.random.normal(0, 1, (N, N)))
        X = X + np.transpose(X) + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=(N)))
        X = X/np.sqrt(N)
        X = X + 3*np.eye(N)

        ## Noise
        Y = np.random.randn(N, M)
        Y = Y/np.sqrt(N)

        W = np.random.randn(N, M)
        W = W/np.sqrt(N)

        ### Observation
        S = np.sqrt(SNR) * X @ Y + W

        ### SVD
        # Only U and the singular values are used, so skip the M x M
        # right-singular-vector matrix that the full SVD would build.
        U_s, S_s, _ = LA.svd(S, full_matrices=False)

        gS = g_CauchyK_num(S_s)

        XX = X @ X

        X_norm = LA.norm(X)**2
        XX_norm = LA.norm(XX)**2

        ### Oracle Estimator for X & X^2
        # k-th entry is u_k^T A u_k for the k-th column u_k of U_s.
        e_hat_X_oracle = np.sum(U_s * (X @ U_s), axis=0)
        e_hat_XX_oracle = np.sum(U_s * (XX @ U_s), axis=0)

        # (U * e) @ U^T equals U @ diag(e) @ U^T without forming diag(e).
        X_hat_oracle = (U_s * e_hat_X_oracle) @ U_s.T
        XX_hat_oracle = (U_s * e_hat_XX_oracle) @ U_s.T

        E_X_oracle[i] = (LA.norm(X - X_hat_oracle)**2) / X_norm
        E_XX_oracle[i] = (LA.norm(XX - XX_hat_oracle)**2) / XX_norm

        #### RIE for X & X^2
        e_hat_X, e_hat_XX = Estimator(S_s, gX, gS, SNR, a)

        X_hat = (U_s * e_hat_X) @ U_s.T
        E_X_RIE[i] = (LA.norm(X - X_hat)**2) / X_norm

        # NOTE(review): np.sqrt yields NaN for any negative entry of
        # e_hat_XX; left unchanged so saved results keep their meaning.
        X_hat_sqXX = (U_s * np.sqrt(e_hat_XX)) @ U_s.T
        E_X_sqXX[i] = (LA.norm(X - X_hat_sqXX)**2) / X_norm

        XX_hat = (U_s * e_hat_XX) @ U_s.T
        # Fixed: this used to be written into E_X_RIE, clobbering the X error
        # above and leaving E_XX_RIE as all zeros.
        E_XX_RIE[i] = (LA.norm(XX - XX_hat)**2) / XX_norm

    # Same file names as before; N is no longer hard-coded in the string.
    tag = f'Wigner_N={N}_M={M}_SNR={SNR}'
    np.save(f'X-{tag}_Oracle.npy', E_X_oracle)
    np.save(f'XX-{tag}_Oracle.npy', E_XX_oracle)
    np.save(f'X-{tag}_RIE.npy', E_X_RIE)
    np.save(f'X-{tag}_sqXX.npy', E_X_sqXX)
    np.save(f'XX-{tag}_RIE.npy', E_XX_RIE)

# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()
    
